"""
Phase 5: Forward Projections — Who's At Risk?
===============================================
Project fiscal dominance risk using UN WPP demographic projections.

Input:  fiscal_dominance/data/processed/fiscal_panel.csv
        fiscal_dominance/data/processed/fiscal_regimes.csv (optional; fiscal_stress fallback)
        multilateral/followup/data/processed/full_panel.csv (for Z polynomials through 2060)
Output: fiscal_dominance/output/tables/phase5_*.csv
"""

import sys
import pandas as pd
import numpy as np
from pathlib import Path

# Project root; every other path below is derived from it.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
# Sub-project roots.
FD_DIR = PROJECT_DIR.joinpath("fiscal_dominance")
MULTILATERAL_DIR = PROJECT_DIR.joinpath("multilateral")
# Inputs (processed panels) and outputs (CSV tables) of the fiscal-dominance phase.
PROCESSED_DIR = FD_DIR.joinpath("data", "processed")
TABLE_DIR = FD_DIR.joinpath("output", "tables")

sys.path.insert(0, str(PROJECT_DIR / "multilateral" / "src"))
from model import PanelGLS


def fit_and_report(y, X, entity_ids, time_ids, feature_names, label):
    """Fit a PanelGLS model, print a one-line header, and return the fitted model.

    Args:
        y: Dependent-variable values, one per observation.
        X: Regressor matrix whose columns are ordered as ``feature_names``.
        entity_ids: Panel entity identifier per observation (``main`` passes
            iso3 codes).
        time_ids: Time identifier per observation (``main`` passes years).
        feature_names: Column labels forwarded to ``PanelGLS.summary``.
        label: Text printed in front of the observation count and R².

    Returns:
        The fitted ``PanelGLS`` instance itself (not a summary table). Callers
        read its ``constant`` and ``beta`` attributes to make predictions.
    """
    model = PanelGLS()
    model.fit(y, X, entity_ids, time_ids)
    print(f"\n  {label}: N={model.n_obs:,}, R²={model.r_squared:.4f}")
    # Return value of summary() is ignored; presumably it prints the
    # coefficient table itself — confirm in multilateral/src/model.py.
    model.summary(feature_names=feature_names)
    return model


def _latest_controls(df, iso3, controls, control_means):
    """Return each control's most recent non-missing value for ``iso3``.

    A control that is missing in the country's final year still uses the
    country's last observed value. Only a control the country never observes
    falls back to ``control_means`` (then 0).

    Returns None when ``iso3`` has no rows in ``df``.
    """
    country = df[df['iso3'] == iso3].sort_values('year')
    if country.empty:
        return None
    values = {}
    for c in controls:
        observed = country[c].dropna()
        values[c] = observed.iloc[-1] if len(observed) else control_means.get(c, 0)
    return values


def _latest_debt(df, iso3):
    """Return the most recent non-missing ``govt_debt_gdp`` for ``iso3``, or NaN."""
    debt = df.loc[df['iso3'] == iso3].sort_values('year')['govt_debt_gdp'].dropna()
    return debt.iloc[-1] if len(debt) else np.nan


def _z_row(panel, iso3, year, demo_vars):
    """Return the first ``panel`` row for (iso3, year) as a Series.

    Returns None when no such row exists or any of ``demo_vars`` is NaN in it.
    """
    rows = panel[(panel['iso3'] == iso3) & (panel['year'] == year)]
    if rows.empty or rows[demo_vars].iloc[0].isna().any():
        return None
    return rows.iloc[0]


def _predict(model, var_names, values):
    """Return ``model.constant + model.beta · x``, with x ordered by ``var_names``.

    Names absent from ``values`` are taken as 0. Returns None if any input is
    NaN, so callers can skip the prediction.

    NOTE(review): only the pooled constant and slopes are used. If PanelGLS
    also estimates entity/time effects, they are not added back here; confirm
    this is intended.
    """
    x_vec = np.array([values.get(v, 0) for v in var_names], dtype=float)
    if np.isnan(x_vec).any():
        return None
    return model.constant + np.dot(model.beta, x_vec)


def main():
    """Run Phase 5: fit demographic projection models and project fiscal risk.

    Steps:
      1. Fit PanelGLS models of r-g (A), fiscal stress (B) and primary balance
         (C) on the historical fiscal panel. Each model is fit only when it has
         at least 200 complete observations. Model C is reported but not used
         for projections.
      2. Project r-g and fiscal stress every 5 years from 2000 to 2060 for a
         fixed list of focus countries. Z polynomials come from
         ``full_panel.csv``; controls are held at each country's latest
         observed values.
      3. Rank every country in ``full_panel.csv`` by a composite 2040 risk
         score: the mean of the z-scored projected r-g, projected stress and
         latest debt.
      4. Compare 2040 r-g with kaopen set to 0 vs 1 for the first ten focus
         countries.
      5. Print the required primary surplus ``(r-g)/100 * debt`` for 2025,
         2035 and 2045.

    Writes phase5_country_projections.csv, phase5_risk_ranking.csv and
    phase5_kaopen_scenarios.csv into TABLE_DIR, creating it if needed.
    """
    min_obs = 200  # minimum complete observations required to fit a model

    print("=" * 70)
    print("PHASE 5: Forward Projections — Who's At Risk?")
    print("=" * 70)

    # to_csv does not create parent directories.
    TABLE_DIR.mkdir(parents=True, exist_ok=True)

    # Load current panel
    df = pd.read_csv(PROCESSED_DIR / "fiscal_panel.csv")
    print(f"Fiscal panel: {len(df):,} obs, {df['iso3'].nunique()} countries")

    # Load full_panel for Z polynomials (includes projections beyond 2024)
    fp = pd.read_csv(MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv")
    print(f"Full panel (with projections): {fp['year'].min()}-{fp['year'].max()}")

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    controls = [c for c in ['fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag']
                if c in df.columns]

    # =================================================================
    # 1. Estimate projection models on historical data
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  Estimating Projection Models")
    print("=" * 70)

    # Model A: Z -> r-g
    vars_rg = demo_vars + controls
    m_rg = None
    est_rg = df.dropna(subset=['r_minus_g'] + vars_rg)
    if len(est_rg) >= min_obs:
        m_rg = fit_and_report(
            est_rg['r_minus_g'].values, est_rg[vars_rg].values,
            est_rg['iso3'].values, est_rg['year'].values,
            vars_rg, "Projection Model A: Z -> r-g",
        )

    # Model B: Z -> fiscal stress (pull it from fiscal_regimes.csv if absent)
    if 'fiscal_stress' not in df.columns:
        regimes_path = PROCESSED_DIR / "fiscal_regimes.csv"
        if regimes_path.exists():
            regime_df = pd.read_csv(regimes_path)
            # df has no fiscal_stress column here, so the merged column keeps
            # its own name. Duplicate (iso3, year) keys in the regimes file
            # would otherwise duplicate panel rows in the left merge.
            stress_df = regime_df[['iso3', 'year', 'fiscal_stress']].drop_duplicates(
                subset=['iso3', 'year'])
            df = df.merge(stress_df, on=['iso3', 'year'], how='left')

    vars_stress = demo_vars + controls
    m_stress = None
    if 'fiscal_stress' in df.columns:
        est_stress = df.dropna(subset=['fiscal_stress'] + vars_stress)
        if len(est_stress) >= min_obs:
            m_stress = fit_and_report(
                est_stress['fiscal_stress'].values, est_stress[vars_stress].values,
                est_stress['iso3'].values, est_stress['year'].values,
                vars_stress, "Projection Model B: Z -> Fiscal Stress",
            )

    # Model C: Z -> primary balance (reported only; not used in projections)
    vars_pb = [v for v in demo_vars + ['debt_lag'] + controls if v in df.columns]
    if 'primary_bal_gdp' in df.columns:
        est_pb = df.dropna(subset=['primary_bal_gdp'] + vars_pb)
        if len(est_pb) >= min_obs:
            fit_and_report(
                est_pb['primary_bal_gdp'].values, est_pb[vars_pb].values,
                est_pb['iso3'].values, est_pb['year'].values,
                vars_pb, "Projection Model C: Z -> Primary Balance",
            )

    # =================================================================
    # 2. Project using future Z polynomials
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  Country Risk Projections (2000-2060)")
    print("=" * 70)

    focus_countries = [
        'JPN', 'DEU', 'USA', 'KOR', 'CHN', 'IND', 'IDN', 'NGA',
        'BRA', 'GBR', 'AUS', 'FRA', 'ITA', 'ESP', 'ZAF',
        'SAU', 'MEX', 'THA', 'POL', 'TUR',
    ]
    proj_years = list(range(2000, 2065, 5))

    # Sample means for controls (fallback when a country never observes one)
    control_means = {c: df[c].mean() for c in controls}

    projection_rows = []
    latest_debts = {}  # iso3 -> latest observed debt/GDP, reused in section 5
    for iso3 in focus_countries:
        ctrl = _latest_controls(df, iso3, controls, control_means)
        if ctrl is None:
            continue
        latest_debt = _latest_debt(df, iso3)
        latest_debts[iso3] = latest_debt

        for year in proj_years:
            zrow = _z_row(fp, iso3, year, demo_vars)
            if zrow is None:
                continue
            z_vals = {zv: zrow[zv] for zv in demo_vars}
            inputs = {**ctrl, **z_vals}

            row = {
                'iso3': iso3,
                'year': year,
                'old_dep': zrow.get('old_dep', np.nan),
                'Z_1': z_vals['Z_1'],
            }

            if m_rg is not None:
                rg_pred = _predict(m_rg, vars_rg, inputs)
                if rg_pred is not None:
                    row['proj_r_minus_g'] = rg_pred

            if m_stress is not None:
                stress_pred = _predict(m_stress, vars_stress, inputs)
                if stress_pred is not None:
                    row['proj_fiscal_stress'] = stress_pred

            # Required primary surplus = (r-g) * debt/GDP. r-g is in pp and
            # debt in % of GDP, hence the /100. This is the first-order form
            # of (r-g)/(1+g) * d.
            if 'proj_r_minus_g' in row and not np.isnan(latest_debt):
                row['required_primary_surplus'] = row['proj_r_minus_g'] / 100 * latest_debt

            projection_rows.append(row)

    proj = pd.DataFrame(projection_rows)
    if len(proj) > 0:
        proj.to_csv(TABLE_DIR / "phase5_country_projections.csv", index=False)
        print(f"Saved country projections: {len(proj)} rows")

        if 'proj_r_minus_g' in proj.columns:
            pivot_rg = proj.pivot(index='iso3', columns='year', values='proj_r_minus_g')
            print(f"\n  Projected r-g by Demographics Alone (pp):")
            print(pivot_rg.to_string(float_format='%.2f'))

    # =================================================================
    # 3. "Who's At Risk?" — Top-20 ranking
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  Who's At Risk? (Top-20 Ranking for 2040)")
    print("=" * 70)

    # Use all countries, not just focus
    risk_rows = []
    for iso3 in fp['iso3'].unique():
        zrow = _z_row(fp, iso3, 2040, demo_vars)
        if zrow is None:
            continue
        ctrl = _latest_controls(df, iso3, controls, control_means)
        if ctrl is None:
            continue
        inputs = {**ctrl, **{zv: zrow[zv] for zv in demo_vars}}

        row = {
            'iso3': iso3,
            'old_dep_2040': zrow.get('old_dep', np.nan),
            'latest_debt': _latest_debt(df, iso3),
        }

        if m_rg is not None:
            rg_pred = _predict(m_rg, vars_rg, inputs)
            if rg_pred is not None:
                row['proj_rg_2040'] = rg_pred

        if m_stress is not None:
            stress_pred = _predict(m_stress, vars_stress, inputs)
            if stress_pred is not None:
                row['proj_stress_2040'] = stress_pred

        risk_rows.append(row)

    risk_df = pd.DataFrame(risk_rows)
    if len(risk_df) > 0:
        # Composite risk score: mean of cross-country z-scores (NaNs skipped)
        for col in ['proj_rg_2040', 'proj_stress_2040', 'latest_debt']:
            if col in risk_df.columns:
                risk_df[f'{col}_z'] = (risk_df[col] - risk_df[col].mean()) / risk_df[col].std()

        score_cols = [c for c in risk_df.columns if c.endswith('_z')]
        if score_cols:
            risk_df['risk_score'] = risk_df[score_cols].mean(axis=1)
            risk_df = risk_df.sort_values('risk_score', ascending=False)

            print(f"\n  Top-20 At-Risk Countries (by composite risk score):")
            display_cols = ['iso3', 'old_dep_2040', 'latest_debt', 'risk_score']
            if 'proj_rg_2040' in risk_df.columns:
                display_cols.insert(3, 'proj_rg_2040')
            print(risk_df[display_cols].head(20).to_string(index=False, float_format='%.2f'))

        risk_df.to_csv(TABLE_DIR / "phase5_risk_ranking.csv", index=False)

    # =================================================================
    # 4. KAOPEN scenarios (open vs closed)
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  KAOPEN Scenario: Open vs Closed")
    print("=" * 70)

    if m_rg is not None and 'kaopen' in vars_rg:
        scenario_rows = []
        for iso3 in focus_countries[:10]:
            zrow = _z_row(fp, iso3, 2040, demo_vars)
            if zrow is None:
                continue
            ctrl = _latest_controls(df, iso3, controls, control_means)
            if ctrl is None:
                continue
            base_inputs = {**ctrl, **{zv: zrow[zv] for zv in demo_vars}}

            # NOTE(review): 0/1 as closed/open assumes kaopen is the normalized
            # (0-1) Chinn-Ito index; confirm the scale used in fiscal_panel.csv.
            for kaopen_val, label in [(0, 'Closed'), (1, 'Open')]:
                rg_pred = _predict(m_rg, vars_rg, {**base_inputs, 'kaopen': kaopen_val})
                if rg_pred is None:
                    continue
                scenario_rows.append({
                    'iso3': iso3,
                    'scenario': label,
                    'proj_rg_2040': rg_pred,
                })

        if scenario_rows:
            scenario_df = pd.DataFrame(scenario_rows)
            pivot = scenario_df.pivot(index='iso3', columns='scenario', values='proj_rg_2040')
            pivot['diff_open_minus_closed'] = pivot['Open'] - pivot['Closed']
            print(pivot.to_string(float_format='%.2f'))
            pivot.to_csv(TABLE_DIR / "phase5_kaopen_scenarios.csv")

    # =================================================================
    # 5. Debt sustainability: required primary surplus
    # =================================================================
    print(f"\n{'=' * 70}")
    print("  Required Primary Surplus = (r-g) * debt/GDP")
    print("=" * 70)

    if 'proj_r_minus_g' in proj.columns:
        for year_show in [2025, 2035, 2045]:
            yr_proj = proj[proj['year'] == year_show].copy()
            if len(yr_proj) == 0:
                continue
            yr_proj['latest_debt'] = yr_proj['iso3'].map(latest_debts)
            if not yr_proj['latest_debt'].notna().any():
                continue
            yr_proj['required_ps'] = yr_proj['proj_r_minus_g'] / 100 * yr_proj['latest_debt']

            yr_proj = yr_proj.sort_values('required_ps', ascending=False)
            print(f"\n  {year_show} Required Primary Surplus (% GDP):")
            print(yr_proj[['iso3', 'old_dep', 'proj_r_minus_g', 'latest_debt', 'required_ps']]
                  .dropna().head(15).to_string(index=False, float_format='%.2f'))

    print(f"\n{'=' * 70}")
    print("Phase 5 complete.")
    print("=" * 70)


# Run the full Phase 5 pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
